import os
import glob
import numpy as np
import matplotlib.pyplot as plt
import concurrent.futures
import csv

def read_vtf_coordinates(vtf_file):
    """Read per-frame coordinates from a VTF trajectory file.

    Every line that starts with ``timestep`` closes the frame collected so
    far and opens a new one.  Any other line consisting of exactly three
    numeric tokens is taken as one ``x y z`` coordinate row.  Blank lines,
    ``#`` comments, ``atom`` lines and every other line are ignored.
    Coordinate rows that appear before the first ``timestep`` line are
    collected as a frame of their own.

    Args:
        vtf_file: Path of the VTF file to read.

    Returns:
        A list with one ``numpy`` array of shape ``(n_rows, 3)`` per frame
        that contained at least one coordinate row.
    """
    frames = []
    coords = []
    with open(vtf_file, 'r') as f:
        for raw_line in f:
            # Strip first so indented keywords and comments are still recognised.
            line = raw_line.strip()
            if line.startswith("timestep"):
                if coords:
                    frames.append(np.array(coords))
                    coords = []
                continue
            if not line or line.startswith(("#", "atom")):
                continue
            parts = line.split()
            if len(parts) != 3:
                continue
            try:
                coords.append([float(x) for x in parts])
            except ValueError:
                # Three-token structure lines (e.g. "bond 0:1, 1:2") are not
                # coordinates; skip them instead of aborting the whole file.
                continue
    if coords:
        frames.append(np.array(coords))
    return frames

def center_of_mass(coords, indices):
    """Return the mean position of the rows of ``coords`` picked by ``indices``.

    No masses are involved: the result is the plain (unweighted) average of
    the selected rows, taken along axis 0.

    Args:
        coords: Coordinate array, indexed along its first axis.
        indices: Row selection applied as ``coords[indices]``.

    Returns:
        The per-column mean of the selected rows.
    """
    selected_rows = coords[indices]
    return np.mean(selected_rows, axis=0)

def calc_distance(com1, com2):
    """Return the length of the vector ``com1 - com2``.

    The length is ``numpy.linalg.norm`` with its default settings, which for
    one-dimensional inputs is the Euclidean distance between the two points.
    """
    separation = com1 - com2
    return np.linalg.norm(separation)

def calc_dihedral(p1, p2, p3, p4):
    """Return the signed dihedral angle, in degrees, defined by four points.

    The angle is measured between the plane through ``p1, p2, p3`` and the
    plane through ``p2, p3, p4``.  Its magnitude comes from the angle between
    the two plane normals; its sign is the sign of ``(n1 x n2) . v23``, where
    ``n1``/``n2`` are the normals and ``v23 = p3 - p2`` is the central bond.

    Args:
        p1, p2, p3, p4: 3-component ``numpy`` vectors.

    Returns:
        The angle in degrees, within ``[-180, 180]``.  A planar cis
        arrangement gives ``0`` and a planar trans arrangement gives ``180``.
        If either plane is undefined (three consecutive points are collinear
        or coincide, so a normal has zero length) ``0.0`` is returned.
    """
    v12 = p2 - p1
    v23 = p3 - p2
    v34 = p4 - p3

    n1 = np.cross(v12, v23)
    n2 = np.cross(v23, v34)

    # Protection against zero-length normals (degenerate geometry).
    n1_len = np.linalg.norm(n1)
    n2_len = np.linalg.norm(n2)
    if n1_len == 0.0 or n2_len == 0.0:
        return 0.0

    n1 = n1 / n1_len
    n2 = n2 / n2_len
    # v23 cannot be the zero vector here, because n1 would then be zero too.
    v23_norm = v23 / np.linalg.norm(v23)

    # Clip to guard arccos against round-off slightly outside [-1, 1].
    cos_angle = np.clip(np.dot(n1, n2), -1.0, 1.0)
    angle_deg = np.degrees(np.arccos(cos_angle))

    # A zero sign term means the normals are parallel (0 deg) or antiparallel
    # (180 deg); arccos already distinguishes the two, so only a strictly
    # negative sign changes the result.
    if np.dot(np.cross(n1, n2), v23_norm) < 0:
        angle_deg = -angle_deg
    return angle_deg

def analyze_vtf_file_dual_chain(vtf_file, regions_chainA, regions_chainB):
    """Compute region-1-to-region-4 distances and dihedrals for both chains.

    For every frame read from ``vtf_file`` the mean position of each of the
    four regions is computed per chain.  From those four points the distance
    between region 1 and region 4 and the dihedral over regions 1-2-3-4 are
    derived.

    Args:
        vtf_file: Path of the VTF file to analyse.
        regions_chainA: Sequence whose first four items are the index
            selections of chain A's regions 1 to 4.
        regions_chainB: Same as ``regions_chainA``, for chain B.

    Returns:
        A ``(distances, dihedrals)`` tuple of two equally long lists.  Values
        of both chains are merged into the same lists: each frame contributes
        its chain A value followed by its chain B value.
    """
    merged_distances = []
    merged_dihedrals = []
    for frame in read_vtf_coordinates(vtf_file):
        # Chain A is handled before chain B to keep the per-frame A, B order.
        for chain_regions in (regions_chainA, regions_chainB):
            c1, c2, c3, c4 = [
                center_of_mass(frame, chain_regions[k]) for k in range(4)
            ]
            end_to_end = calc_distance(c1, c4)
            torsion = calc_dihedral(c1, c2, c3, c4)
            merged_distances.append(end_to_end)
            merged_dihedrals.append(torsion)
    return merged_distances, merged_dihedrals

# ===== USER INPUT =====
folder = './VTFs'   # <--- your folder; every *.vtf file inside it is analysed
# ---- Define indices for each region and each chain ----
# Each region is a range(start, stop) of coordinate-row indices; `stop` itself
# is excluded.
# NOTE(review): as written, indices 1195, 3610 and 4850 (and their chain B
# counterparts) fall in no region -- confirm that these single-index gaps are
# intended and not the result of treating `stop` as inclusive.
# Chain A indices
region1A_indices = range(0, 1195)
region2A_indices = range(1196, 2345)
region3A_indices = range(2480, 3610)
region4A_indices = range(3611, 4850)
# Chain B indices: every chain A bound shifted by the same offset of 4851.
region1B_indices = range(4851, 6046)
region2B_indices = range(6047, 7196)
region3B_indices = range(7331, 8461)
# Start is 3611 + 4851 = 8462; 8461 maps to chain A index 3610, which belongs
# to no chain A region, so it must not be included here either.
region4B_indices = range(8462, 9701)
# Combine for convenience
regions_chainA = [region1A_indices, region2A_indices, region3A_indices, region4A_indices]
regions_chainB = [region1B_indices, region2B_indices, region3B_indices, region4B_indices]
output_data_csv = 'distance_dihedral_data.csv'
output_data_npy = 'distance_dihedral_data.npz'
num_workers = 16    # maximum number of worker processes

# ------ Parallel analysis of both chains, followed by export ------
# The process pool is only created when this file is run as a script.  Under
# the "spawn" and "forkserver" start methods every worker process imports the
# main module, so creating the pool at import time would make each worker try
# to start the analysis again.
if __name__ == "__main__":
    # Sorted so that the row order of the exported data is the same on every run.
    vtf_files = sorted(glob.glob(os.path.join(folder, "*.vtf")))
    print(f"Found {len(vtf_files)} VTF files.")

    all_distances = []
    all_dihedrals = []

    with concurrent.futures.ProcessPoolExecutor(max_workers=num_workers) as executor:
        futures = [
            executor.submit(
                analyze_vtf_file_dual_chain,
                vtf_file,
                regions_chainA,
                regions_chainB
            )
            for vtf_file in vtf_files
        ]
        # Collect in submission order (not completion order) to keep the
        # merged output deterministic; result() re-raises any worker error.
        for future in futures:
            local_distances, local_dihedrals = future.result()
            all_distances.extend(local_distances)
            all_dihedrals.extend(local_dihedrals)

    x = np.array(all_distances)
    y = np.array(all_dihedrals)

    # ---- Save to CSV and NPZ (merged results for both chains) ----
    with open(output_data_csv, 'w', newline='') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(['distance', 'dihedral'])
        writer.writerows(zip(x, y))
    np.savez(output_data_npy, distance=x, dihedral=y)

    print(f"Saved merged chain analysis to {output_data_csv} (CSV) and {output_data_npy} (npy).")

###################################################### To Plot: ###################################
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

# Overlay two distance/dihedral distributions as hexbin density maps on one
# set of axes.  Guarded so that merely importing this module (as worker
# processes of a process pool may do) never reads the data files or opens a
# plot window.
if __name__ == "__main__":
    # Load the first dataset (.npz); the context manager closes the archive
    # once both arrays have been read into memory.
    with np.load('distance_dihedral_data.npz') as data:
        x = data['distance']
        y = data['dihedral']

    # Load the second dataset (.csv); it must provide 'distance' and
    # 'dihedral' columns.
    df2 = pd.read_csv('all_atom_MD_measurements.csv')
    x2, y2 = df2['distance'].values, df2['dihedral'].values

    # Determine unified plot limits for both datasets
    min_x = min(x.min(), x2.min())
    max_x = max(x.max(), x2.max())
    min_y = min(y.min(), y2.min())
    max_y = max(y.max(), y2.max())

    plt.figure(figsize=(8, 6))

    # Set axis range before plotting
    plt.xlim(min_x, max_x)
    plt.ylim(min_y, max_y)

    # Set plot background color
    plt.gca().set_facecolor('lightgray')

    # Plot first dataset as hexbin (Upside).  Both layers use the same extent
    # and gridsize, so their hexagons line up.
    hb1 = plt.hexbin(
        x, y,
        gridsize=60,
        cmap='cool',
        alpha=0.7,
        extent=(min_x, max_x, min_y, max_y),
        mincnt=1,
        zorder=1,
        label='Upside'
    )

    # Colorbar for first hexbin
    cbar1 = plt.colorbar(hb1, label='Upside Counts')
    cbar1.ax.tick_params(labelsize=8)

    # Plot second dataset as hexbin (All Atom), with the same extent.
    # alpha must be non-zero: with alpha=0 this layer is fully transparent
    # and the comparison dataset never shows up in the figure.
    hb2 = plt.hexbin(
        x2, y2,
        gridsize=60,
        cmap='afmhot_r',
        alpha=0.7,
        extent=(min_x, max_x, min_y, max_y),
        mincnt=1,
        zorder=2,
        label='All Atom'
    )

    # Colorbar for second hexbin
    cbar2 = plt.colorbar(hb2, label='All Atom Counts')
    cbar2.ax.tick_params(labelsize=8)

    plt.xlabel('D1-D4 Distance (Å)')
    plt.ylabel('Dihedral Angle (°)')
    plt.tight_layout()
    plt.show()

